---
title: "Figure-4-and-5"
format: html
editor: visual
---

## Simulation Figures (4 and 5)

```{r}
# =============================================================================
# Imports & Setup
# =============================================================================
library(tidyverse)    # includes dplyr, tidyr, ggplot2, etc.
library(RColorBrewer) # for color palettes
library(cowplot)

# Make theme_bw() the session-wide default ggplot2 theme. Both figures below
# also add theme_bw(base_size = 12) explicitly, so this default mainly matters
# for any other plot drawn in the same session.
theme_set(theme_bw())

# =============================================================================
# Reused Function: Transform simulation data from wide to long format
# =============================================================================
transform_sim_data_to_long <- function(sim_data,
                                       id_vars = c("n_label", "gs_acc", "q_acc")) {
  # Reshape simulation results from one column per (estimator, metric) pair
  # into one row per (setting, metric, estimator).
  #
  # Args:
  #   sim_data: data frame containing the six result columns so_bias,
  #             dsl_bias, so_rmse, dsl_rmse, so_coverage and dsl_coverage,
  #             plus every column named in id_vars.
  #   id_vars:  character vector of columns that identify a simulation
  #             setting; they are carried through and placed first.
  #
  # Returns:
  #   A long data frame with columns id_vars, metric ("bias", "rmse",
  #   "coverage"), estimator ("so", "dsl") and value.
  result_cols <- c("so_bias", "dsl_bias", "so_rmse", "dsl_rmse",
                   "so_coverage", "dsl_coverage")

  # Column names have the form "<estimator>_<metric>", so they are split on
  # the underscore while pivoting.
  long_data <- pivot_longer(
    sim_data,
    cols = all_of(result_cols),
    names_to = c("estimator", "metric"),
    names_sep = "_",
    values_to = "value"
  )

  # Keep only the identifying columns and the three long-format columns, in
  # the order: id_vars, metric, estimator, value.
  select(long_data, all_of(id_vars), metric, estimator, value)
}

# =============================================================================
# ---------------------- Sim0: Test Run (naoki-sim0.csv) -----------------------
# =============================================================================
# Preliminary run: it only exercises transform_sim_data_to_long() with a
# non-default set of id columns. No figure is produced from it.
sim0_data <- read.csv(
  "~/Downloads/dsl-image-simulation/results/final/naoki-sim0.csv"
)[, -1]  # first column dropped (presumably a saved row index -- confirm)

# This file is keyed by q_acc, n_label and n_covariates (there is no gs_acc).
# Metric names are capitalised for display: "rmse" becomes "RMSE", the others
# are title-cased ("Bias", "Coverage").
plot_data0 <- sim0_data %>%
  transform_sim_data_to_long(
    id_vars = c("q_acc", "n_label", "n_covariates")
  ) %>%
  mutate(metric = ifelse(metric == "rmse", "RMSE", str_to_title(metric)))

# =============================================================================
# ------------------- Simulation 1 (rendered below as Figure 4) ---------------
# "Effect of Surrogate Accuracy on Bias, RMSE, and Coverage
#  for DSL and SO Estimators"
# =============================================================================

# Load the simulation 1 results; the first column is dropped.
sim1_data <- read.csv(
  "~/Downloads/dsl-image-simulation/results/final/sim1.csv"
)[, -1]

# Reshape to long format, capitalise the metric names for display
# ("rmse" -> "RMSE", others title-cased) and force the key columns to numeric
# (going through character so that factor input converts by label).
plot_data1 <- sim1_data %>%
  transform_sim_data_to_long() %>%
  mutate(
    metric = ifelse(metric == "rmse", "RMSE", str_to_title(metric)),
    across(c(q_acc, gs_acc, n_label),
           function(col) as.numeric(as.character(col)))
  )

# Sorted distinct values of the design variables (used further down).
q_acc_vals <- sort(unique(plot_data1$q_acc))
n_label_vals <- sort(unique(plot_data1$n_label))

# Facet strip labels, keyed by the capitalised metric name.
yaxis_labels <- c(
  Bias = "Mean Absolute Bias",
  RMSE = "Root Mean Squared Error",
  Coverage = "Nominal Coverage of 95% CI"
)

# ---------------------------------------------------------------------------
# Colour palettes, legend labels and marker shapes per estimator.
#
# Each estimator gets a sequential ColorBrewer palette with one colour per
# q_acc value: "Reds" for dsl, "Blues" for so. One extra colour is requested
# and the first (lightest) one is dropped, mirroring the original Python code.
# NOTE(review): brewer.pal() only supports 3 to 9 colours for these palettes,
# so this assumes between 2 and 8 distinct q_acc values -- confirm.
# ---------------------------------------------------------------------------
n_q <- length(q_acc_vals)
dsl_palette <- brewer.pal(n_q + 1, "Reds")[-1]
so_palette  <- brewer.pal(n_q + 1, "Blues")[-1]

# Fix the estimator order (so before dsl) and add a combined label of the form
# "<ESTIMATOR>: <q_acc>", e.g. "SO: 0.9" or "DSL: 0.8". The number is formatted
# by paste0(), so trailing zeros are not kept.
plot_data1 <- plot_data1 %>%
  mutate(estimator = factor(estimator, levels = c("so", "dsl")),
         legend_label = paste0(toupper(estimator), ": ", q_acc))

# Named colour vector with one entry per (estimator, q_acc) combination. The
# names are built exactly like legend_label above so the two can be matched.
# Built in one vectorised step instead of growing a vector inside nested loops;
# the i-th q_acc value gets the i-th colour of the estimator's palette.
color_mapping <- c(
  setNames(so_palette[seq_len(n_q)], paste0("SO: ", q_acc_vals)),
  setNames(dsl_palette[seq_len(n_q)], paste0("DSL: ", q_acc_vals))
)

# Marker shapes: filled square (15) for "so", filled circle (16) for "dsl".
# Note that the Figure 4 chunk sets its own shapes in scale_shape_manual()
# rather than using this mapping.
shape_mapping <- c("so" = 15, "dsl" = 16)

```

## Figure 4

```{r}
# Figure 4: Bias, RMSE and Coverage of the SO and DSL estimators as a function
# of surrogate accuracy (q_acc), one facet per metric.

# Keep only the runs with 500 labels. n_label was converted to numeric when
# plot_data1 was built, so it is compared with a number rather than a string.
plot_data1_sub <- subset(plot_data1, n_label == 500)

p_main <- ggplot(plot_data1_sub,
                 aes(x = q_acc,
                     y = value,
                     group = estimator,
                     color = estimator)) +
  geom_line() +
  geom_point(aes(shape = estimator), size = 4) +
  # Each facet has its own y scale; strips show the long metric descriptions.
  facet_wrap(~ metric, scales = "free_y",
             labeller = labeller(metric = function(x) yaxis_labels[x])) +
  labs(x = "Surrogate accuracy", y = "Value") +
  # Colour and shape share the same name and labels so ggplot2 merges them
  # into a single "Estimator" legend.
  scale_color_manual(
    name = "Estimator",
    values = c("so" = "royalblue2", "dsl" = "tomato"),
    labels = c("so" = "Surrogate-only", "dsl" = "DSL")
  ) +
  scale_shape_manual(
    name = "Estimator",
    values = c("so" = 17, "dsl" = 16),  # triangle for SO, circle for DSL
    labels = c("so" = "Surrogate-only", "dsl" = "DSL")
  ) +
  theme_bw(base_size = 12) +
  theme(
    plot.title   = element_text(hjust = 0.5, size = 16),
    axis.text.x  = element_text(size = 11),
    axis.text.y  = element_text(size = 11),
    axis.title   = element_text(size = 14),
    strip.text   = element_text(size = 14),
    legend.position = "right"
  ) +
  # Dashed reference line at 0.95 in the Coverage facet only. The layer data
  # is reduced to a single row so the line is drawn once, not once per
  # observation.
  geom_hline(
    data = plot_data1_sub %>% filter(metric == "Coverage") %>% distinct(metric),
    aes(yintercept = 0.95),
    color = "black",
    linetype = "dashed",
    inherit.aes = FALSE,   # do not pick up estimator/colour from the main aes
    show.legend = FALSE
  ) +
  # Force the Coverage facet's y-axis to span [0, 1] with an invisible layer.
  # expand.grid() pairs every q_acc with both 0 and 1; data.frame() would fail
  # here whenever the number of distinct q_acc values is odd (and > 1),
  # because c(0, 1) cannot be recycled to that length.
  geom_blank(
    data = expand.grid(
      q_acc  = unique(plot_data1_sub$q_acc),
      value  = c(0, 1),
      metric = "Coverage",
      stringsAsFactors = FALSE
    ),
    mapping = aes(x = q_acc, y = value),
    inherit.aes = FALSE,
    show.legend = FALSE
  ) +
  # Dotted reference line at y = 0 in the Bias facet only (single-row data,
  # as for the Coverage line above).
  geom_hline(
    data = plot_data1_sub %>% filter(metric == "Bias") %>% distinct(metric),
    aes(yintercept = 0),
    color = "black",
    linetype = "dotted",
    inherit.aes = FALSE,
    show.legend = FALSE
  )

p_main

```

## Figure 5

```{r}
# ---------------------------
# Read and transform simulation 2 data.
# ---------------------------
sim2_data <- read.csv("~/Downloads/dsl-image-simulation/results/final/sim2.csv")[, -1]
data2 <- transform_sim_data_to_long(sim2_data)

# Capitalise metric names ("rmse" -> "RMSE", others title-cased) and make the
# key columns numeric BEFORE filtering on them, so the comparisons below are
# always numeric. Then keep the rows with q_acc equal to 0.75 and
# gs_acc >= 0.75. near() is used instead of == to tolerate floating-point
# representation error.
plot_data2 <- data2 %>%
  mutate(metric = ifelse(metric == "rmse", "RMSE", str_to_title(metric)),
         q_acc = as.numeric(as.character(q_acc)),
         gs_acc = as.numeric(as.character(gs_acc)),
         n_label = as.numeric(as.character(n_label))) %>%
  filter(near(q_acc, 0.75), gs_acc >= 0.75)

n_label_vals2 <- sort(unique(plot_data2$n_label))
gs_acc_vals <- sort(unique(plot_data2$gs_acc))
# Facet strip labels, keyed by the capitalised metric name.
yaxis_labels2 <- c(Bias = "Mean Absolute Bias",
                   RMSE = "Root Mean Squared Error",
                   Coverage = "Nominal Coverage of 95% CI")

# ---------------------------
# Prepare DSL data and color mapping.
# ---------------------------
# Only the DSL estimator is plotted as a curve over gs_acc. gs_acc becomes a
# factor (levels in increasing order) so it is treated as a discrete axis.
dsl_data <- plot_data2 %>% filter(estimator == "dsl")
dsl_data$gs_acc <- factor(dsl_data$gs_acc, levels = gs_acc_vals)
n_gs <- length(gs_acc_vals)
# "Reds" palette with one colour per gs_acc level; one extra colour is
# requested and the first (lightest) one is dropped.
dsl_palette2 <- brewer.pal(n_gs + 1, "Reds")[-1]
gs_color_mapping <- setNames(dsl_palette2, levels(dsl_data$gs_acc))

# ---------------------------
# Compute the SO baseline:
# mean value per metric of the "so" estimator at gs_acc equal to 0.75.
# NOTE(review): this mean is taken over every n_label in plot_data2, while the
# Figure 5 plot shows DSL at n_label == 500 only -- confirm this is intended.
# ---------------------------
so_baseline <- plot_data2 %>%
  filter(estimator == "so", near(gs_acc, 0.75)) %>%
  group_by(metric) %>%
  summarize(mean_value = mean(value, na.rm = TRUE)) %>%
  ungroup()

# The plot itself is built in the next chunk.
```

```{r}
# ---------------------------
# Subset to DSL data for n_label == 500.
# ---------------------------
dsl_data_sub <- dsl_data %>%
  filter(n_label == 500)

# ---------------------------
# Figure 5: DSL Bias, RMSE and Coverage against expert accuracy (gs_acc), one
# facet per metric, with the SO baseline as a horizontal line. gs_acc is not
# mapped to colour, so it has no legend.
# NOTE(review): `size` on line geoms is deprecated in favour of `linewidth`
# from ggplot2 3.4.0; it is kept here so the chunk also runs on older versions.
# ---------------------------
p2_new <- ggplot(dsl_data_sub, aes(x = gs_acc, y = value)) +
  # gs_acc is a factor, so group = 1 is needed to join all points with a
  # single (tomato) line.
  geom_line(aes(group = 1), color = "tomato", size = 1) +
  # Tomato circles at each gs_acc level (constant colour, so no legend).
  geom_point(size = 4, color = "tomato", shape = 16) +
  facet_wrap(
    ~ metric,
    scales = "free_y",
    labeller = labeller(metric = function(x) yaxis_labels2[x])
  ) +
  labs(x = "Expert accuracy", y = "Value") +
  theme_bw(base_size = 12) +
  theme(
    plot.title    = element_text(hjust = 0.5, size = 16),
    axis.title    = element_text(size = 14),
    axis.text     = element_text(size = 12),
    legend.position = "right",   # legend holds only the baseline entry
    legend.title  = element_text(size = 14),
    legend.text   = element_text(size = 12),
    strip.text    = element_text(size = 14)
  ) +
  # ---------------------------
  # SO baseline: one horizontal line per facet at that metric's mean SO value.
  # Mapping linetype to a constant string is what creates the legend entry.
  # ---------------------------
  geom_hline(
    data = so_baseline,
    aes(yintercept = mean_value, linetype = "SO Baseline"),
    color = "blue", size = 1,
    inherit.aes = FALSE
  ) +
  scale_linetype_manual(
    name   = "Baseline",
    values = c("SO Baseline" = "dotdash")
  ) +
  # ---------------------------
  # Coverage: dashed line at 0.95 + force y from 0 to 1
  # ---------------------------
  geom_hline(
    data = dsl_data_sub %>% filter(metric == "Coverage") %>% distinct(metric),
    aes(yintercept = 0.95),
    color = "black", linetype = "dashed", size = 0.5,
    inherit.aes = FALSE, show.legend = FALSE
  ) +
  # Invisible layer that stretches the Coverage facet to [0, 1].
  # expand.grid() pairs every gs_acc level with both 0 and 1; data.frame()
  # would fail here whenever the number of gs_acc levels is odd (and > 1),
  # because c(0, 1) cannot be recycled to that length.
  geom_blank(
    data = expand.grid(
      gs_acc = unique(dsl_data_sub$gs_acc),
      value  = c(0, 1),
      metric = "Coverage",
      stringsAsFactors = FALSE
    ),
    mapping = aes(x = gs_acc, y = value),
    inherit.aes = FALSE,
    show.legend = FALSE
  ) +
  # ---------------------------
  # Bias: dotted line at 0
  # ---------------------------
  geom_hline(
    data = dsl_data_sub %>% filter(metric == "Bias") %>% distinct(metric),
    aes(yintercept = 0),
    color = "black", linetype = "dotted", size = 0.5,
    inherit.aes = FALSE, show.legend = FALSE
  )

print(p2_new)

```
